#!/bin/bash -x
#SBATCH --nodes=64
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --account=EUHPC_E03_068
#SBATCH --partition=boost_usr_prod
#SBATCH --threads-per-core=1
#SBATCH --time=24:00:00
#SBATCH --output=/leonardo_work/EUHPC_E03_068/DCFT_shared/outputs/mlfoundations-dev_hero-1_Qwen_Qwen2.5-7B-Instruct_open_thoughts_template/logs/slurm-%j.out
#SBATCH --exclude=lrdn[2663,2479,2075,3269,3145,2363,2269,1507,1927,1444,0880,1718,2358,1013,2139,0863,2976,2465,0736,1355,0843,2829,3136,1817,1539,2141,2490,3074,1058,2595,3132,2596,1506,2857,2689,2084,2404,2734,0701,2362,1474,1488,1509,1511,2122,2820,0999,1542,3279,2798,1240,3126,0806,1149,0772,1001,2654,2917,0860,2690,2581,1953,1647,2808,2598,2101,3006,1702,1235,2927,2997,3110,3282,3020,0032,2652,0973,1071,2898,2375,3277,1751,2753,2120,2755,1591,1308,2682,1658,1357,2792,1158]

# Abort the job on the first failed command.
set -e

# Build the container-launch prefix. APPTAINER_ARGS is kept as a plain
# string because it is expanded unquoted on the srun line at the bottom
# of this script (deliberate word-splitting into separate arguments).
#
# NOTE(review): IMAGE is only exported further down in this file, so this
# branch sees IMAGE from the *submission* environment (if any). If the
# later `export IMAGE=...` is meant to drive this choice, that export
# must be moved above this block — confirm intended behavior.
if [ -z "$IMAGE" ]; then
    APPTAINER_ARGS=""
else
    # Only bind-mount paths that are actually set; an unset variable
    # would otherwise produce a broken "--bind :" argument.
    BIND_ARGS=""
    for bind_path in "$SCRATCH" "$HOME" "$WORK" "$HF_HOME"; do
        if [ -n "$bind_path" ]; then
            BIND_ARGS="$BIND_ARGS --bind $bind_path:$bind_path"
        fi
    done
    APPTAINER_ARGS="singularity run $BIND_ARGS --nv $IMAGE"
fi
# --- NCCL / distributed-training environment -------------------------------
export NCCL_NET_GDR_LEVEL=PIX # Use GPU Direct RDMA when GPU and NIC are on the same PCI switch
export NCCL_SOCKET_IFNAME=ib0 # Pin NCCL socket traffic to the InfiniBand interface
export NCCL_IB_TIMEOUT=120
export NCCL_DEBUG=INFO
# export TRANSFORMERS_OFFLINE=1
export WANDB_MODE=offline
export WANDB_DIR="/leonardo_work/EUHPC_E03_068/DCFT_shared/outputs/mlfoundations-dev_hero-1_Qwen_Qwen2.5-7B-Instruct_open_thoughts_template"
export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export OMP_NUM_THREADS=1

# Rendezvous endpoint: the first node of the allocation acts as master.
export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
export MASTER_PORT=20043

# Derive the node count from the actual allocation (falling back to the
# requested 64 when SLURM_NNODES is unavailable) so it cannot drift from
# the #SBATCH --nodes directive at the top of this file.
export NUM_NODES="${SLURM_NNODES:-64}"
export NUM_GPUS_PER_NODE=4
export NUM_GPUS=$((NUM_GPUS_PER_NODE * NUM_NODES))

DEEPSPEED_CONFIG_FILE=dcft/train/zero3_offload.json
OUTPUT_DIR="/leonardo_work/EUHPC_E03_068/DCFT_shared/outputs/mlfoundations-dev_hero-1_Qwen_Qwen2.5-7B-Instruct_open_thoughts_template"
mkdir -p "$OUTPUT_DIR"
TMP_DIR="$OUTPUT_DIR/tmp"
mkdir -p "$TMP_DIR"

# --- Container image, caches, and toolchain --------------------------------
# Shared Apptainer/Singularity cache and scratch locations.
export APPTAINER_CACHEDIR="$WORK/shared/APPTAINER_CACHEDIR/"
export APPTAINER_TMPDIR="$WORK/shared/APPTAINER_TMPDIR/"
export IMAGE="$WORK/shared/container_images/dcft_latest.sif"
# export IMAGE="/leonardo_work/EUHPC_E03_068/shared/container_images/dcft_cuda12_1_latest.sif"

# Hugging Face endpoint and cache location.
export HF_ENDPOINT="https://hf-mirror.com"
export HF_HOME="$WORK/tvu00001/hf_cache"
export DCFT_PRIVATE="$WORK/tvu00001/dcft_private"
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

# Activate the Python virtualenv and load the CUDA/compiler/NCCL modules.
source "$WORK/tvu00001/r1/bin/activate"
module load cuda/12.1
module load gcc/12.2.0
module load nccl

# Write a per-job accelerate config into TMP_DIR. The heredoc delimiter
# (EOT) is unquoted, so every $VAR below is expanded NOW, at submit time;
# the generated YAML contains concrete values for this job.
ACCELERATE_CONFIG_FILE="$TMP_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"
cat << EOT > $ACCELERATE_CONFIG_FILE
# WARNING: do not edit this file as this is an slurm-auto-generated file
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: $DEEPSPEED_CONFIG_FILE
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config:
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT
# Per-task launcher string, later executed via `srun ... bash -c "$CMD"`.
# Escaping is deliberate and load-bearing:
#   - $MASTER_ADDR / $MASTER_PORT / $ACCELERATE_CONFIG_FILE expand now;
#   - \$SLURM_PROCID and \$(hostname -s) are backslash-escaped so they
#     expand later, once per task, overriding the `machine_rank: 0`
#     placeholder written into the config above.
# The inner double quotes around the --rdzv_conf value close and reopen
# the outer string, so the value becomes one unquoted word on expansion.
LAUNCHER="python -u -m accelerate.commands.launch \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --config_file $ACCELERATE_CONFIG_FILE \
    --machine_rank \$SLURM_PROCID \
    --role \$(hostname -s): --tee 3 \
    "
# Training entry point: the accelerate launcher runs the LLaMA-Factory
# train script against the experiment YAML below.
CONFIG=/leonardo_work/EUHPC_E03_068/DCFT_shared/dcft_private/dcft/train/configs/reasoning/mlfoundations-dev_hero-1_Qwen_Qwen2.5-7B-Instruct_open_thoughts_template.yaml
CMD="$LAUNCHER dcft/train/llamafactory/src/train.py $CONFIG"

# srun options, kept as a whitespace-joined string because it is expanded
# unquoted on the srun line below (deliberate word-splitting).
SRUN_ARGS="
    --nodes=$NUM_NODES \
    --gres=gpu:$NUM_GPUS_PER_NODE \
    --cpus-per-task=$SLURM_CPUS_PER_TASK \
    --wait=60 \
    --kill-on-bad-exit=1 \
    --label \
    --jobid $SLURM_JOBID"

# Run from the repo root so the relative train.py / deepspeed-config paths
# resolve. ${VAR:?} aborts loudly if DCFT_PRIVATE is unset or empty; the
# previous unquoted `cd $DCFT_PRIVATE` silently fell back to `cd` ($HOME)
# in that case and would have launched from the wrong directory.
cd "${DCFT_PRIVATE:?DCFT_PRIVATE must be set}"
srun $SRUN_ARGS $APPTAINER_ARGS bash -c "$CMD"